#!/bin/bash
# Starts Stage 2 training - fine-tunes full decoder and adapter with high masking ratios
# to specialize in initial token generation
#
# Usage:
# 1. chmod +x scripts/04_train_stage2_decoder_high_ratio.sh
# 2. ./scripts/04_train_stage2_decoder_high_ratio.sh

# Strict mode: abort on command failure (-e), on use of an unset variable (-u),
# and when any stage of a pipeline fails (pipefail).
set -euo pipefail

echo "Starting Stage 2: Decoder Specialization (High Mask Ratio)..."

# Default trainer arguments, kept in an array so that each word is passed
# through verbatim and groups of options can be commented individually.
# All paths are relative, so the script must be launched from the repository
# root (see the usage notes at the top of this file).
train_args=(
  # Pre-processed training and validation splits (several directories each).
  --train_data_dir data/processed/train-clean-100 data/processed/train-clean-360 data/processed/train-other-500
  --val_data_dir data/processed/dev-clean data/processed/dev-other

  # Input checkpoints: the best adapter saved under the Stage 1 output
  # directory, plus the base model weights.
  --pretrain_path out/stage1_adapter_960h/ft-Diff_LLaMA_170M-1753116650/adapter_best.pt
  --base_model_path pretrained_models/mdm_safetensors/mdm-170M-100e18-rsl-0.01.safetensors

  # Where this stage writes its results, and which model/tokenizer to use.
  --out_dir out/stage2_decoder_960h_high_ratio
  --model_name Diff_LLaMA_170M
  --tokenizer_name TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T

  # Device count and batching.
  --num_devices 4
  --batch_size 128
  --gradient_accumulation_steps 2

  # Learning-rate configuration.
  --learning_rate 1e-5
  --second_stage_lr_multiplier 0.5
  --lr_scaling linear
  --use_layer_wise_lr_decay
  --layer_wise_lr_decay_rate 0.9
  --weight_decay 0.005
  --scheduler_type cosine
  --warmup_ratio 0.1

  # Training length and early-stopping patience.
  --epochs 30
  --patience 5

  # Exponential moving average of the weights.
  --use_ema
  --ema_decay 0.995

  # Also report WER/CER.
  --compute_wer_cer

  # Masking-ratio range; the high lower bound is what makes this the
  # "high ratio" stage described in the file header.
  --min_mask_ratio 0.7
  --max_mask_ratio 1.0
)

# Any arguments given to this script are appended after the defaults, e.g.
#   ./scripts/04_train_stage2_decoder_high_ratio.sh --epochs 10
# NOTE(review): whether a repeated flag overrides the default depends on the
# trainer's argument parser (presumably argparse, where the last occurrence
# wins) - confirm before relying on it.
# With no arguments the invocation is identical to the fixed command line.
python -m src.training.train_stage2_decoder_high_ratio "${train_args[@]}" "$@"

# Only reached when the trainer exits with status 0 (set -e aborts otherwise).
echo "✅ Stage 2 training script finished."